#!/usr/bin/env python3

import os
from datasets import load_dataset

def main(
    *,
    test_size: int = 1000,
    seed: int = 42,
    out_dir: str = "dataset/countdown",
) -> None:
    """Split the Countdown dataset into test and train JSON Lines files.

    Loads the ``train`` split of ``Jiayi-Pan/Countdown-Tasks-3to4``, shuffles
    it with a fixed seed, takes the first ``test_size`` shuffled rows as the
    test set and all remaining rows as the training set, then writes them to
    ``test.jsonl`` and ``train.jsonl`` inside ``out_dir``.

    Args:
        test_size: Number of shuffled rows to hold out as the test set.
        seed: Seed passed to ``Dataset.shuffle`` so the split is reproducible.
        out_dir: Directory that receives both output files; created if it
            does not exist.

    Raises:
        ValueError: If ``test_size`` is negative or larger than the dataset.
    """
    ds = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")

    # Fail early with a clear message instead of an index error from
    # ``select`` (or a silently wrong split for a negative size).
    total = len(ds)
    if not 0 <= test_size <= total:
        raise ValueError(
            f"test_size must be between 0 and {total}, got {test_size}"
        )

    ds = ds.shuffle(seed=seed)
    test_ds = ds.select(range(test_size))
    train_ds = ds.select(range(test_size, total))

    os.makedirs(out_dir, exist_ok=True)

    test_path = os.path.join(out_dir, "test.jsonl")
    train_path = os.path.join(out_dir, "train.jsonl")

    # One JSON object per line (JSON Lines) for each split.
    test_ds.to_json(test_path, orient="records", lines=True)
    train_ds.to_json(train_path, orient="records", lines=True)

    print(f"Saved {len(test_ds)} examples to {test_path}")
    print(f"Saved {len(train_ds)} examples to {train_path}")

# Run the split only when this file is executed as a script, so importing
# the module does not trigger the dataset load or write any files.
if __name__ == "__main__":
    main()
